import os
import os.path as osp

import gym

import diffgro
from diffgro.environments import make_env as _make_env
from diffgro.common.evaluations import evaluate
from diffgro.utils.config import load_config, save_config
from diffgro.utils import Parser, make_dir, print_r, print_y, print_b
from train import *


# Representative task used to build a placeholder environment for each
# supported domain; the environment is passed to `make_buff` and to the
# planner constructor in `train`.
_DUMMY_ENV_TASKS = {
    'metaworld': 'push-variant-v2',
    'metaworld_complex': 'puck-drawer-button-stick-variant-v2',
}


def _make_dummy_env(domain_name):
    """Build the placeholder environment for ``domain_name``.

    Raises:
        NotImplementedError: if ``domain_name`` has no placeholder task.
    """
    try:
        dummy_task = _DUMMY_ENV_TASKS[domain_name]
    except KeyError:
        raise NotImplementedError(
            f"Unsupported domain '{domain_name}'; expected one of "
            f"{sorted(_DUMMY_ENV_TASKS)}"
        ) from None
    return _make_env(domain_name, dummy_task)


def _get_save_path(args, domain_name, task_name):
    """Return the results directory for this run, selected by ``args.phase``.

    Raises:
        NotImplementedError: if ``args.phase`` is neither 0 nor 1.
    """
    if args.phase == 0:
        return osp.join("./results/diffbc", domain_name, task_name, args.tag)
    if args.phase == 1:
        # NOTE(review): phase 1 uses the domain name for both path levels
        # instead of the task name — presumably a domain-wide run; confirm.
        return osp.join("./results/diffbc", domain_name, domain_name, args.tag)
    raise NotImplementedError(f"Unsupported phase: {args.phase}")


def _evaluate_planner(args, planner, domain_name, task_name, task_list, save_path):
    """Evaluate ``planner`` on each task in ``task_list`` and save the results.

    If ``task_name`` is not ``"all"``, only that single task is evaluated.
    Aggregate results are written with ``eval_save`` only when more than one
    task was evaluated. ``args.env_name`` is rewritten per task before calling
    ``make_env(args)`` and is restored to its original value afterwards.
    """
    if task_name != "all":
        task_list = [task_name]  # evaluate one task

    tot_success = []
    original_env_name = args.env_name
    try:
        for task in task_list:
            args.env_name = f"{domain_name}.{task}"
            # Use fresh names so the outer domain/task are not clobbered.
            eval_env, eval_domain, eval_task = make_env(args)
            model = diffgro.DiffBC(eval_env, planner, verbose=args.verbose)
            success = evaluate(
                model, eval_env, eval_domain, eval_task,
                args.n_episodes, True, args.video, save_path,
            )
            tot_success.extend(success)
    finally:
        # Do not leave the caller's args pointing at the last evaluated task.
        args.env_name = original_env_name

    if len(task_list) > 1:
        eval_save(tot_success, save_path)


def train(args):
    """Train and/or evaluate a DiffBC planner.

    ``args.env_name`` must have the form ``"<domain>.<task>"``. If
    ``args.train`` is set, a planner is trained on the buffer built by
    ``make_buff`` and saved under the results directory. If ``args.test`` is
    set, the saved planner is loaded and evaluated.

    Raises:
        NotImplementedError: for an unsupported domain or ``args.phase``.
    """
    # 0. Make Dummy Environment
    domain_name, task_name = args.env_name.split(".")
    env = _make_dummy_env(domain_name)
    print_y(f"Obs Space: {env.observation_space.shape}, Act Space: {env.action_space.shape}")

    # 1. Save Path
    save_path = _get_save_path(args, domain_name, task_name)
    model_path = osp.join(save_path, "planner")

    # 2. Make Buffer
    buff, task_list = make_buff(args, env)
    num_task = len(task_list)
    print_r(f"Number of tasks {num_task}")

    if args.train:
        # 3. Load Config
        config = load_config('./config/algos/diffbc.yml', domain_name)
        planner_cfg = config['planner']
        # Multi-task runs scale the configured batch size by the task count
        # and train longer; single-task runs use fixed values.
        planner_cfg['params']['batch_size'] = (
            planner_cfg['params']['batch_size'] * num_task if num_task > 1 else 64
        )
        planner_cfg['training']['total_timesteps'] = (
            1_000_000 if num_task > 1 else 200_000
        )
        planner_cfg['params']['seed'] = args.seed
        planner_cfg['datasets'] = args.dataset_path

        # 4. Make Models
        model = diffgro.DiffBCPlanner(
            env=env,
            replay_buffer=buff,
            **planner_cfg['params'],
            verbose=1,
            policy_kwargs=planner_cfg['policy_kwargs'],
        )

        # 5. Training & Saving
        make_dir(save_path)
        save_config(save_path, config)  # save configs
        model.learn(**planner_cfg['training'])
        model.save(path=model_path)

    if args.test:
        # 6. Evaluation
        planner = diffgro.DiffBCPlanner.load(model_path)
        _evaluate_planner(args, planner, domain_name, task_name, task_list, save_path)


if __name__ == "__main__":
    # Script entry point: parse CLI options with the "train" parser and run.
    train(Parser("train").parse_args())
